#ifndef OPENMC_RANDOM_RAY_SOURCE_REGION_H
#define OPENMC_RANDOM_RAY_SOURCE_REGION_H

#include "openmc/openmp_interface.h"
#include "openmc/position.h"
#include "openmc/random_ray/moment_matrix.h"
#include "openmc/settings.h"

namespace openmc {

//----------------------------------------------------------------------------
// Helper Functions

// The hash_combine function is the standard hash combine function from boost
// that is typically used for combining multiple hash values into a single hash
// as is needed for larger objects being stored in a hash map. The function is
// taken from:
// https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine
// which carries the following license:
//
// Boost Software License - Version 1.0 - August 17th, 2003
// Permission is hereby granted, free of charge, to any person or organization
// obtaining a copy of the software and accompanying documentation covered by
// this license (the "Software") to use, reproduce, display, distribute,
// execute, and transmit the Software, and to prepare derivative works of the
// Software, and to permit third-parties to whom the Software is furnished to
// do so, all subject to the following:
// The copyright notices in the Software and this entire statement, including
// the above license grant, this restriction and the following disclaimer,
// must be included in all copies of the Software, in whole or in part, and
// all derivative works of the Software, unless such copies or derivative
// works are solely in the form of machine-executable object code generated by
// a source language processor.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
inline void hash_combine(size_t& seed, const size_t v)
{
  // Mix the new value with the 32-bit golden-ratio constant and two shifted
  // copies of the running seed, then fold the mix into the seed with XOR.
  // Unsigned arithmetic wraps, so the additions may be grouped freely.
  const size_t golden = 0x9e3779b9;
  const size_t seed_mix = (seed << 6) + (seed >> 2);
  seed ^= v + golden + seed_mix;
}

//----------------------------------------------------------------------------
// Helper Structs and Classes

// A mapping object that is used to map between a specific random ray
// source region and an OpenMC native tally bin that it should score to
// every iteration.
//
// All members carry default initializers so that a default-constructed
// TallyTask has a well-defined value. Instances are compared and hashed when
// stored in an unordered_set, and reading indeterminate ints would be
// undefined behavior.
struct TallyTask {
  int tally_idx {0};  //!< Tally index
  int filter_idx {0}; //!< Filter (bin) index within the tally
  int score_idx {0};  //!< Score index within the tally
  int score_type {0}; //!< Score type identifier
  TallyTask(int tally_idx, int filter_idx, int score_idx, int score_type)
    : tally_idx(tally_idx), filter_idx(filter_idx), score_idx(score_idx),
      score_type(score_type)
  {}
  TallyTask() = default;

  // Comparison and Hash operators are defined to allow usage of the
  // TallyTask struct as a key in an unordered_set. Two tasks are equal only
  // when all four fields match.
  bool operator==(const TallyTask& other) const
  {
    return tally_idx == other.tally_idx && filter_idx == other.filter_idx &&
           score_idx == other.score_idx && score_type == other.score_type;
  }

  // Hashes all four fields (in declaration order) with hash_combine, so the
  // hash is consistent with operator==.
  struct HashFunctor {
    size_t operator()(const TallyTask& task) const
    {
      size_t seed = 0;
      hash_combine(seed, task.tally_idx);
      hash_combine(seed, task.filter_idx);
      hash_combine(seed, task.score_idx);
      hash_combine(seed, task.score_type);
      return seed;
    }
  };
};

// The SourceRegionKey combines a base source region (i.e., a material
// filled cell instance) with a mesh bin. This key is used as a handle
// for dynamically discovered source regions when subdividing source
// regions with meshes.
class SourceRegionKey {
public:
  // Default member initializers give a default-constructed key a defined
  // value. Keys are hashed, compared and sorted, and reading indeterminate
  // integers would be undefined behavior.
  int64_t base_source_region_id {0}; //!< Base (unsubdivided) source region
  int64_t mesh_bin {0};              //!< Bin of the subdividing mesh
  SourceRegionKey() = default;
  SourceRegionKey(int64_t source_region, int64_t bin)
    : base_source_region_id(source_region), mesh_bin(bin)
  {}

  // Equality operator required by the unordered_map
  bool operator==(const SourceRegionKey& other) const
  {
    return base_source_region_id == other.base_source_region_id &&
           mesh_bin == other.mesh_bin;
  }

  // Less than operator required by std::sort. Orders lexicographically:
  // first by base source region, then by mesh bin.
  bool operator<(const SourceRegionKey& other) const
  {
    if (base_source_region_id != other.base_source_region_id) {
      return base_source_region_id < other.base_source_region_id;
    }
    return mesh_bin < other.mesh_bin;
  }

  // Hashing functor required by the unordered_map; combines both fields so
  // it is consistent with operator==.
  struct HashFunctor {
    size_t operator()(const SourceRegionKey& key) const
    {
      size_t seed = 0;
      hash_combine(seed, key.base_source_region_id);
      hash_combine(seed, key.mesh_bin);
      return seed;
    }
  };
};

// Forward declaration of SourceRegion
class SourceRegion;

// Non-owning view of a single source region. Every field is a pointer to
// storage owned elsewhere: a SourceRegion (via the constructor below) or,
// presumably, the SoA arrays of a SourceRegionContainer (via
// SourceRegionContainer::get_source_region_handle) — confirm in the .cpp.
// The accessors dereference those pointers without any null or bounds checks,
// so a handle is only usable while the storage it points into is alive and
// has not been reallocated.
class SourceRegionHandle {
public:
  //----------------------------------------------------------------------------
  // Constructors
  SourceRegionHandle(SourceRegion& sr);
  // A default-constructed handle points at nothing. All pointers are null, so
  // misuse fails deterministically instead of dereferencing garbage.
  SourceRegionHandle() = default;

  // All fields are commented/described in the SourceRegion class definition
  // below

  //----------------------------------------------------------------------------
  // Public Data members
  int negroups_ {0};
  bool is_numerical_fp_artifact_ {false};
  bool is_linear_ {false};

  // Scalar fields
  int* material_ {nullptr};
  int* is_small_ {nullptr};
  int* n_hits_ {nullptr};
  // NOTE(review): no accessor in this class and no matching field in
  // SourceRegion in this header; confirm it is still used.
  int* birthday_ {nullptr};
  OpenMPMutex* lock_ {nullptr};
  double* volume_ {nullptr};
  double* volume_t_ {nullptr};
  double* volume_sq_ {nullptr};
  double* volume_sq_t_ {nullptr};
  double* volume_naive_ {nullptr};
  int* position_recorded_ {nullptr};
  int* external_source_present_ {nullptr};
  Position* position_ {nullptr};
  Position* centroid_ {nullptr};
  Position* centroid_iteration_ {nullptr};
  Position* centroid_t_ {nullptr};
  MomentMatrix* mom_matrix_ {nullptr};
  MomentMatrix* mom_matrix_t_ {nullptr};
  // A set of volume tally tasks. This more complicated data structure is
  // convenient for ensuring that volumes are only tallied once per source
  // region, regardless of how many energy groups are used for tallying.
  std::unordered_set<TallyTask, TallyTask::HashFunctor>* volume_task_ {
    nullptr};

  // Mesh that subdivides this source region
  int* mesh_ {nullptr};
  int64_t* parent_sr_ {nullptr};

  // Energy group-wise 1D arrays (indexed by group g in [0, negroups_))
  double* scalar_flux_old_ {nullptr};
  double* scalar_flux_new_ {nullptr};
  float* source_ {nullptr};
  float* external_source_ {nullptr};
  double* scalar_flux_final_ {nullptr};

  MomentArray* source_gradients_ {nullptr};
  MomentArray* flux_moments_old_ {nullptr};
  MomentArray* flux_moments_new_ {nullptr};
  MomentArray* flux_moments_t_ {nullptr};

  // 2D array representing values for all energy groups x tally
  // tasks. Each group may have a different number of tally tasks
  // associated with it, necessitating the use of a jagged array.
  vector<TallyTask>* tally_task_ {nullptr};

  //----------------------------------------------------------------------------
  // Public Accessors
  //
  // Mutable accessors return a reference into the underlying storage. Const
  // accessors return a copy, except volume_task() and tally_task(), which
  // return const references.

  int& material() { return *material_; }
  const int material() const { return *material_; }

  int& is_small() { return *is_small_; }
  const int is_small() const { return *is_small_; }

  int& n_hits() { return *n_hits_; }
  const int n_hits() const { return *n_hits_; }

  // Lock/unlock the mutex of the referenced region; calls must be paired.
  void lock() { lock_->lock(); }
  void unlock() { lock_->unlock(); }

  double& volume() { return *volume_; }
  const double volume() const { return *volume_; }

  double& volume_t() { return *volume_t_; }
  const double volume_t() const { return *volume_t_; }

  double& volume_sq() { return *volume_sq_; }
  const double volume_sq() const { return *volume_sq_; }

  double& volume_sq_t() { return *volume_sq_t_; }
  const double volume_sq_t() const { return *volume_sq_t_; }

  double& volume_naive() { return *volume_naive_; }
  const double volume_naive() const { return *volume_naive_; }

  int& position_recorded() { return *position_recorded_; }
  const int position_recorded() const { return *position_recorded_; }

  int& external_source_present() { return *external_source_present_; }
  const int external_source_present() const
  {
    return *external_source_present_;
  }

  Position& position() { return *position_; }
  const Position position() const { return *position_; }

  Position& centroid() { return *centroid_; }
  const Position centroid() const { return *centroid_; }

  Position& centroid_iteration() { return *centroid_iteration_; }
  const Position centroid_iteration() const { return *centroid_iteration_; }

  Position& centroid_t() { return *centroid_t_; }
  const Position centroid_t() const { return *centroid_t_; }

  MomentMatrix& mom_matrix() { return *mom_matrix_; }
  const MomentMatrix mom_matrix() const { return *mom_matrix_; }

  MomentMatrix& mom_matrix_t() { return *mom_matrix_t_; }
  const MomentMatrix mom_matrix_t() const { return *mom_matrix_t_; }

  std::unordered_set<TallyTask, TallyTask::HashFunctor>& volume_task()
  {
    return *volume_task_;
  }
  const std::unordered_set<TallyTask, TallyTask::HashFunctor>& volume_task()
    const
  {
    return *volume_task_;
  }

  int& mesh() { return *mesh_; }
  const int mesh() const { return *mesh_; }

  int64_t& parent_sr() { return *parent_sr_; }
  const int64_t parent_sr() const { return *parent_sr_; }

  double& scalar_flux_old(int g) { return scalar_flux_old_[g]; }
  const double scalar_flux_old(int g) const { return scalar_flux_old_[g]; }

  double& scalar_flux_new(int g) { return scalar_flux_new_[g]; }
  const double scalar_flux_new(int g) const { return scalar_flux_new_[g]; }

  double& scalar_flux_final(int g) { return scalar_flux_final_[g]; }
  const double scalar_flux_final(int g) const { return scalar_flux_final_[g]; }

  float& source(int g) { return source_[g]; }
  const float source(int g) const { return source_[g]; }

  float& external_source(int g) { return external_source_[g]; }
  const float external_source(int g) const { return external_source_[g]; }

  MomentArray& source_gradients(int g) { return source_gradients_[g]; }
  const MomentArray source_gradients(int g) const
  {
    return source_gradients_[g];
  }

  MomentArray& flux_moments_old(int g) { return flux_moments_old_[g]; }
  const MomentArray flux_moments_old(int g) const
  {
    return flux_moments_old_[g];
  }

  MomentArray& flux_moments_new(int g) { return flux_moments_new_[g]; }
  const MomentArray flux_moments_new(int g) const
  {
    return flux_moments_new_[g];
  }

  MomentArray& flux_moments_t(int g) { return flux_moments_t_[g]; }
  const MomentArray flux_moments_t(int g) const { return flux_moments_t_[g]; }

  vector<TallyTask>& tally_task(int g) { return tally_task_[g]; }
  const vector<TallyTask>& tally_task(int g) const { return tally_task_[g]; }

}; // class SourceRegionHandle

// Owning representation of a single source region: every scalar field and
// every energy-group-wise array is stored directly as a member of this object.
// SourceRegionContainer::push_back and SourceRegionContainer::assign accept a
// SourceRegion. NOTE(review): presumably this type is a staging object whose
// contents are copied into the container's SoA arrays; confirm in the .cpp.
class SourceRegion {
public:
  //----------------------------------------------------------------------------
  // Constructors

  // Definition not in this header. Presumably sizes the group-wise arrays to
  // negroups (and the moment arrays when is_linear is true); confirm.
  SourceRegion(int negroups, bool is_linear);
  // Definition not in this header. Presumably copies fields from the region
  // referenced by handle and records parent_sr as its parent; confirm.
  SourceRegion(const SourceRegionHandle& handle, int64_t parent_sr);
  SourceRegion() = default;

  //----------------------------------------------------------------------------
  // Public Data members
  //
  // NOTE: declaration order determines member layout and initialization
  // order; do not reorder casually.

  //---------------------------------------
  // Scalar fields

  int material_ {0}; //!< Index in openmc::model::materials array
  OpenMPMutex lock_; //!< Mutex guarding updates to this region
  //! Volume (computed from the sum of ray crossing lengths)
  double volume_ {0.0};
  double volume_t_ {0.0};     //!< Volume totaled over all iterations
  double volume_sq_ {0.0};    //!< Volume squared
  double volume_sq_t_ {0.0};  //!< Volume squared totaled over all iterations
  double volume_naive_ {0.0}; //!< Volume as integrated from this iteration only
  int position_recorded_ {0}; //!< Has the position been recorded yet?
  //! Is an external source present in this region?
  int external_source_present_ {0};
  int is_small_ {0}; //!< Is it "small", receiving < 1.5 hits per iteration?
  int n_hits_ {0};   //!< Number of total hits (ray crossings)
  //! Index in openmc::model::meshes array of the mesh that subdivides this
  //! source region (C_NONE if not subdivided)
  int mesh_ {C_NONE};
  int64_t parent_sr_ {C_NONE}; //!< Index of a parent source region
  //! A position somewhere inside the region
  Position position_ {0.0, 0.0, 0.0};
  Position centroid_ {0.0, 0.0, 0.0}; //!< The centroid
  //! The centroid integrated from this iteration only
  Position centroid_iteration_ {0.0, 0.0, 0.0};
  //! The centroid accumulated over all iterations
  Position centroid_t_ {0.0, 0.0, 0.0};
  //! The spatial moment matrix
  MomentMatrix mom_matrix_ {0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
  //! The spatial moment matrix accumulated over all iterations
  MomentMatrix mom_matrix_t_ {0.0, 0.0, 0.0, 0.0, 0.0, 0.0};

  // A set of volume tally tasks. This more complicated data structure is
  // convenient for ensuring that volumes are only tallied once per source
  // region, regardless of how many energy groups are used for tallying.
  std::unordered_set<TallyTask, TallyTask::HashFunctor> volume_task_;

  //---------------------------------------
  // Energy group-wise 1D arrays

  //! The scalar flux from the previous iteration
  vector<double> scalar_flux_old_;
  //! The scalar flux from the current iteration
  vector<double> scalar_flux_new_;
  //! The total source term (fission + scattering + external)
  vector<float> source_;
  vector<float> external_source_; //!< The external source term
  //! The scalar flux accumulated over all active iterations (used for
  //! plotting, or computing adjoint sources)
  vector<double> scalar_flux_final_;

  vector<MomentArray> source_gradients_; //!< The linear source gradients
  //! The linear flux moments from the previous iteration
  vector<MomentArray> flux_moments_old_;
  //! The linear flux moments from the current iteration
  vector<MomentArray> flux_moments_new_;
  //! The linear flux moments accumulated over all active iterations (used
  //! for plotting)
  vector<MomentArray> flux_moments_t_;

  //---------------------------------------
  // 2D array representing values for all energy groups x tally
  // tasks. Each group may have a different number of tally tasks
  // associated with it, necessitating the use of a jagged array.
  vector<vector<TallyTask>> tally_task_;
}; // class SourceRegion

// Struct-of-arrays (SoA) storage for all source regions.
//
// Scalar fields hold one entry per source region and are addressed by the
// source region index `sr`. Energy-group-wise fields are 2D
// (source region x group) arrays flattened to 1D. They are addressed either by
// (sr, g) or directly by the flat "source element" index
// se = sr * negroups() + g, where 0 <= se < n_source_elements().
//
// Mutable accessors return references into the underlying vectors. Const
// accessors return copies, except lock(), tally_task() and volume_task(),
// which return const references. No accessor performs bounds checking.
class SourceRegionContainer {
public:
  //----------------------------------------------------------------------------
  // Constructors
  SourceRegionContainer(int negroups, bool is_linear)
    : negroups_(negroups), is_linear_(is_linear)
  {}
  SourceRegionContainer() = default;

  //----------------------------------------------------------------------------
  // Public Accessors
  int& material(int64_t sr) { return material_[sr]; }
  const int material(int64_t sr) const { return material_[sr]; }

  int& is_small(int64_t sr) { return is_small_[sr]; }
  const int is_small(int64_t sr) const { return is_small_[sr]; }

  int& n_hits(int64_t sr) { return n_hits_[sr]; }
  const int n_hits(int64_t sr) const { return n_hits_[sr]; }

  OpenMPMutex& lock(int64_t sr) { return lock_[sr]; }
  const OpenMPMutex& lock(int64_t sr) const { return lock_[sr]; }

  double& volume(int64_t sr) { return volume_[sr]; }
  const double volume(int64_t sr) const { return volume_[sr]; }

  double& volume_t(int64_t sr) { return volume_t_[sr]; }
  const double volume_t(int64_t sr) const { return volume_t_[sr]; }

  double& volume_sq(int64_t sr) { return volume_sq_[sr]; }
  const double volume_sq(int64_t sr) const { return volume_sq_[sr]; }

  double& volume_sq_t(int64_t sr) { return volume_sq_t_[sr]; }
  const double volume_sq_t(int64_t sr) const { return volume_sq_t_[sr]; }

  double& volume_naive(int64_t sr) { return volume_naive_[sr]; }
  const double volume_naive(int64_t sr) const { return volume_naive_[sr]; }

  int& position_recorded(int64_t sr) { return position_recorded_[sr]; }
  const int position_recorded(int64_t sr) const
  {
    return position_recorded_[sr];
  }

  int& external_source_present(int64_t sr)
  {
    return external_source_present_[sr];
  }
  const int external_source_present(int64_t sr) const
  {
    return external_source_present_[sr];
  }

  Position& position(int64_t sr) { return position_[sr]; }
  const Position position(int64_t sr) const { return position_[sr]; }

  Position& centroid(int64_t sr) { return centroid_[sr]; }
  const Position centroid(int64_t sr) const { return centroid_[sr]; }

  Position& centroid_iteration(int64_t sr) { return centroid_iteration_[sr]; }
  const Position centroid_iteration(int64_t sr) const
  {
    return centroid_iteration_[sr];
  }

  Position& centroid_t(int64_t sr) { return centroid_t_[sr]; }
  const Position centroid_t(int64_t sr) const { return centroid_t_[sr]; }

  MomentMatrix& mom_matrix(int64_t sr) { return mom_matrix_[sr]; }
  const MomentMatrix mom_matrix(int64_t sr) const { return mom_matrix_[sr]; }

  MomentMatrix& mom_matrix_t(int64_t sr) { return mom_matrix_t_[sr]; }
  const MomentMatrix mom_matrix_t(int64_t sr) const
  {
    return mom_matrix_t_[sr];
  }

  // Group-wise accessors: (sr, g) overloads translate through index(); the
  // single-argument overloads take a flat source element index `se`.

  MomentArray& source_gradients(int64_t sr, int g)
  {
    return source_gradients_[index(sr, g)];
  }
  const MomentArray source_gradients(int64_t sr, int g) const
  {
    return source_gradients_[index(sr, g)];
  }
  MomentArray& source_gradients(int64_t se) { return source_gradients_[se]; }
  const MomentArray source_gradients(int64_t se) const
  {
    return source_gradients_[se];
  }

  MomentArray& flux_moments_old(int64_t sr, int g)
  {
    return flux_moments_old_[index(sr, g)];
  }
  const MomentArray flux_moments_old(int64_t sr, int g) const
  {
    return flux_moments_old_[index(sr, g)];
  }
  MomentArray& flux_moments_old(int64_t se) { return flux_moments_old_[se]; }
  const MomentArray flux_moments_old(int64_t se) const
  {
    return flux_moments_old_[se];
  }

  MomentArray& flux_moments_new(int64_t sr, int g)
  {
    return flux_moments_new_[index(sr, g)];
  }
  const MomentArray flux_moments_new(int64_t sr, int g) const
  {
    return flux_moments_new_[index(sr, g)];
  }
  MomentArray& flux_moments_new(int64_t se) { return flux_moments_new_[se]; }
  const MomentArray flux_moments_new(int64_t se) const
  {
    return flux_moments_new_[se];
  }

  MomentArray& flux_moments_t(int64_t sr, int g)
  {
    return flux_moments_t_[index(sr, g)];
  }
  const MomentArray flux_moments_t(int64_t sr, int g) const
  {
    return flux_moments_t_[index(sr, g)];
  }
  MomentArray& flux_moments_t(int64_t se) { return flux_moments_t_[se]; }
  const MomentArray flux_moments_t(int64_t se) const
  {
    return flux_moments_t_[se];
  }

  double& scalar_flux_old(int64_t sr, int g)
  {
    return scalar_flux_old_[index(sr, g)];
  }
  const double scalar_flux_old(int64_t sr, int g) const
  {
    return scalar_flux_old_[index(sr, g)];
  }
  double& scalar_flux_old(int64_t se) { return scalar_flux_old_[se]; }
  const double scalar_flux_old(int64_t se) const
  {
    return scalar_flux_old_[se];
  }

  double& scalar_flux_new(int64_t sr, int g)
  {
    return scalar_flux_new_[index(sr, g)];
  }
  const double scalar_flux_new(int64_t sr, int g) const
  {
    return scalar_flux_new_[index(sr, g)];
  }
  double& scalar_flux_new(int64_t se) { return scalar_flux_new_[se]; }
  const double scalar_flux_new(int64_t se) const
  {
    return scalar_flux_new_[se];
  }

  double& scalar_flux_final(int64_t sr, int g)
  {
    return scalar_flux_final_[index(sr, g)];
  }
  const double scalar_flux_final(int64_t sr, int g) const
  {
    return scalar_flux_final_[index(sr, g)];
  }
  double& scalar_flux_final(int64_t se) { return scalar_flux_final_[se]; }
  const double scalar_flux_final(int64_t se) const
  {
    return scalar_flux_final_[se];
  }

  float& source(int64_t sr, int g) { return source_[index(sr, g)]; }
  const float source(int64_t sr, int g) const { return source_[index(sr, g)]; }
  float& source(int64_t se) { return source_[se]; }
  const float source(int64_t se) const { return source_[se]; }

  float& external_source(int64_t sr, int g)
  {
    return external_source_[index(sr, g)];
  }
  const float external_source(int64_t sr, int g) const
  {
    return external_source_[index(sr, g)];
  }
  float& external_source(int64_t se) { return external_source_[se]; }
  const float external_source(int64_t se) const { return external_source_[se]; }

  vector<TallyTask>& tally_task(int64_t sr, int g)
  {
    return tally_task_[index(sr, g)];
  }
  const vector<TallyTask>& tally_task(int64_t sr, int g) const
  {
    return tally_task_[index(sr, g)];
  }
  vector<TallyTask>& tally_task(int64_t se) { return tally_task_[se]; }
  const vector<TallyTask>& tally_task(int64_t se) const
  {
    return tally_task_[se];
  }

  std::unordered_set<TallyTask, TallyTask::HashFunctor>& volume_task(int64_t sr)
  {
    return volume_task_[sr];
  }
  const std::unordered_set<TallyTask, TallyTask::HashFunctor>& volume_task(
    int64_t sr) const
  {
    return volume_task_[sr];
  }

  int& mesh(int64_t sr) { return mesh_[sr]; }
  const int mesh(int64_t sr) const { return mesh_[sr]; }

  int64_t& parent_sr(int64_t sr) { return parent_sr_[sr]; }
  const int64_t parent_sr(int64_t sr) const { return parent_sr_[sr]; }

  //----------------------------------------------------------------------------
  // Public Methods
  //
  // push_back, assign, flux_swap, get_source_region_handle and adjoint_reset
  // are defined outside this header.

  void push_back(const SourceRegion& sr);
  // NOTE(review): the count parameter is int while n_source_regions_ is
  // int64_t; it cannot be widened here without also changing the definition.
  void assign(int n_source_regions, const SourceRegion& source_region);
  void flux_swap();
  int64_t n_source_regions() const { return n_source_regions_; }
  // Total length of every flattened group-wise array
  // (n_source_regions * negroups), computed in int64_t.
  int64_t n_source_elements() const { return n_source_regions_ * negroups_; }
  int& negroups() { return negroups_; }
  const int negroups() const { return negroups_; }
  bool& is_linear() { return is_linear_; }
  const bool is_linear() const { return is_linear_; }
  SourceRegionHandle get_source_region_handle(int64_t sr);
  void adjoint_reset();

private:
  //----------------------------------------------------------------------------
  // Private Data Members
  int64_t n_source_regions_ {0};
  int negroups_ {0};
  bool is_linear_ {false};

  // SoA storage for scalar fields (one item per source region)
  vector<int> material_;
  vector<int> is_small_;
  vector<int> n_hits_;
  vector<int> mesh_;
  vector<int64_t> parent_sr_;
  vector<OpenMPMutex> lock_;
  vector<double> volume_;
  vector<double> volume_t_;
  vector<double> volume_sq_;
  vector<double> volume_sq_t_;
  vector<double> volume_naive_;
  vector<int> position_recorded_;
  vector<int> external_source_present_;
  vector<Position> position_;
  vector<Position> centroid_;
  vector<Position> centroid_iteration_;
  vector<Position> centroid_t_;
  vector<MomentMatrix> mom_matrix_;
  vector<MomentMatrix> mom_matrix_t_;
  // A set of volume tally tasks. This more complicated data structure is
  // convenient for ensuring that volumes are only tallied once per source
  // region, regardless of how many energy groups are used for tallying.
  vector<std::unordered_set<TallyTask, TallyTask::HashFunctor>> volume_task_;

  // SoA energy group-wise 2D arrays flattened to 1D
  vector<double> scalar_flux_old_;
  vector<double> scalar_flux_new_;
  vector<double> scalar_flux_final_;
  vector<float> source_;
  vector<float> external_source_;

  vector<MomentArray> source_gradients_;
  vector<MomentArray> flux_moments_old_;
  vector<MomentArray> flux_moments_new_;
  vector<MomentArray> flux_moments_t_;

  // SoA 3D array representing values for all source regions x energy groups x
  // tally tasks. The outer two dimensions (source regions and energy groups)
  // are flattened to 1D. Each group may have a different number of tally tasks
  // associated with it, necessitating the use of a jagged array for the inner
  // dimension.
  vector<vector<TallyTask>> tally_task_;

  //----------------------------------------------------------------------------
  // Private Methods

  // Helper function for indexing: maps (source region, group) to the flat
  // source element index sr * negroups_ + g. The result must be int64_t, the
  // same type as n_source_elements(). For large problems the product exceeds
  // INT_MAX, and narrowing it to int would silently address the wrong element.
  int64_t index(int64_t sr, int g) const { return sr * negroups_ + g; }
};

} // namespace openmc

#endif // OPENMC_RANDOM_RAY_SOURCE_REGION_H
